# Portions of this file are adapted from
#   SAM (https://github.com/davda54/sam, MIT License).
# Changes: refactored for IAM algorithm, etc.

import torch
from torchvision import datasets, transforms
from torchvision.transforms import AutoAugment, AutoAugmentPolicy
from torch.utils.data import DataLoader, Subset
import random
import numpy as np
from PIL import Image

# Root directory passed to the torchvision dataset constructors in this module
# (used as their storage / download location).
dataset_dir = '/home/dataset/'
# Alternative relative location, kept for local runs:
# dataset_dir = '../data/'


def get_datasets(dataset_name='cifar10'):
    """Return the ``(train_dataset, test_dataset)`` pair for ``dataset_name``.

    Args:
        dataset_name: Either ``'cifar10'`` or ``'cifar100'``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    if dataset_name == 'cifar10':
        return get_cifar10_dataset()
    if dataset_name == 'cifar100':
        return get_cifar100_dataset()
    raise ValueError(f"Unknown dataset: {dataset_name}")

class Cutout:
    """Randomly zero out a square patch of a ``(C, H, W)`` image tensor.

    Args:
        size: Side length of the square patch, in pixels.
        p: Probability of applying the patch to a given image.
    """

    def __init__(self, size=16, p=0.5):
        self.size = size
        self.half_size = size // 2
        self.p = p

    def __call__(self, image):
        """Apply cutout to ``image`` in place and return the same tensor.

        The patch origin is sampled so that the patch may extend past the
        image border; the part outside the image is clipped.
        """
        if torch.rand(1).item() > self.p:
            return image

        c, h, w = image.shape
        left = torch.randint(-self.half_size, w - self.half_size, [1]).item()
        top = torch.randint(-self.half_size, h - self.half_size, [1]).item()
        right = min(w, left + self.size)
        bottom = min(h, top + self.size)

        # Dim 1 is height (rows) and dim 2 is width (columns), so the vertical
        # range must index dim 1 and the horizontal range dim 2. Swapping them
        # misplaces the patch on non-square images.
        image[:, max(0, top): bottom, max(0, left): right] = 0
        return image
    
class TransformFixMatch:
    """Produce a weakly and a strongly augmented view of the same image.

    Both pipelines flip horizontally, reflect-pad and crop to 32x32, convert
    to a tensor and normalize with ``mean`` / ``std``. The strong pipeline
    additionally runs ``RandAugment`` before tensor conversion and finishes
    with ``RandomErasing``.
    """

    def __init__(self, mean, std):
        weak_ops = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        strong_ops = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
            RandAugment(n=2, m=10),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            transforms.RandomErasing(p=0.5),  # Cutout-like effect
        ]
        self.weak = transforms.Compose(weak_ops)
        self.strong = transforms.Compose(strong_ops)

    def __call__(self, x):
        """Return ``(weak_view, strong_view)`` computed from the input ``x``."""
        weak_view = self.weak(x)
        strong_view = self.strong(x)
        return weak_view, strong_view

class CIFAR10SSL:
    """Index-selected view over image and label arrays.

    ``data`` and ``targets`` are indexed with ``idx`` once at construction;
    items are converted to PIL images and passed through ``transform`` when
    one is given.
    """

    def __init__(self, data, targets, idx, transform=None):
        self.data, self.targets = data[idx], targets[idx]
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        raw = self.data[index]
        target = self.targets[index]
        img = Image.fromarray(raw)

        if self.transform is None:
            return img, target
        # A TransformFixMatch transform yields a (weak, strong) pair and any
        # other transform a single output; both are returned next to the label.
        return self.transform(img), target

class RandAugment:
    """Apply ``n`` randomly chosen transforms from a fixed pool to an image.

    Operations are drawn with ``np.random.choice`` (its default sampling, so
    the same operation may be picked more than once). ``m`` is stored on the
    instance but is not consumed by any operation in the pool.
    """

    def __init__(self, n=2, m=10):
        self.n, self.m = n, m
        pool = [
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomPerspective(distortion_scale=0.5, p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.RandomSolarize(threshold=128),
            transforms.RandomAutocontrast(),
            transforms.RandomAdjustSharpness(sharpness_factor=2),
            transforms.RandomEqualize(),
        ]
        self.augment_pool = pool

    def __call__(self, img):
        """Return ``img`` after applying the sampled operations in order."""
        augmented = img
        for transform_op in np.random.choice(self.augment_pool, self.n):
            augmented = transform_op(augmented)
        return augmented


def get_cifar10_loaders(batch_size=128, num_workers=4, model_type='WRN', autoaugment = False):
    """Build the CIFAR-10 training and test DataLoaders.

    Args:
        batch_size: Batch size of the training loader (the test loader uses 1000).
        num_workers: Number of worker processes per loader; ``0`` is supported
            and loads data in the main process.
        model_type: ``"ViT"`` resizes images to 224; any other value keeps 32.
        autoaugment: When true, the training pipeline additionally applies the
            CIFAR10 AutoAugment policy and ``Cutout``.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    input_size = 224 if model_type == "ViT" else 32
    mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)

    train_ops = [
        transforms.Resize(input_size),
        transforms.RandomCrop(input_size, padding=input_size//8),
        transforms.RandomHorizontalFlip(),
    ]
    if autoaugment:
        train_ops.append(AutoAugment(policy=AutoAugmentPolicy.CIFAR10))
    train_ops.extend([transforms.ToTensor(), transforms.Normalize(mean, std)])
    if autoaugment:
        # Cutout works on the tensor, so it is applied after normalization.
        train_ops.append(Cutout())
    transform_train = transforms.Compose(train_ops)

    transform_test = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_dataset = datasets.CIFAR10(root=dataset_dir, train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR10(root=dataset_dir, train=False, download=True, transform=transform_test)

    # persistent_workers / prefetch_factor only apply to worker processes;
    # DataLoader rejects them when num_workers == 0, so pass them conditionally.
    worker_kwargs = {'persistent_workers': True, 'prefetch_factor': 2} if num_workers > 0 else {}

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True, **worker_kwargs
    )
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=num_workers)

    return train_loader, test_loader

def get_cifar100_loaders(batch_size=128, num_workers=4, autoaugment = False):
    """Build the CIFAR-100 training and test DataLoaders.

    Args:
        batch_size: Batch size of the training loader (the test loader uses 1000).
        num_workers: Number of worker processes per loader; ``0`` is supported
            and loads data in the main process.
        autoaugment: When true, the training pipeline additionally applies the
            CIFAR10 AutoAugment policy and ``Cutout(size=8)``.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    cifar100_mean = (0.5071, 0.4867, 0.4408)
    cifar100_std  = (0.2675, 0.2565, 0.2761)

    train_ops = [
        transforms.Resize(32),
        transforms.RandomCrop(32, padding=32//8),
        transforms.RandomHorizontalFlip(),
    ]
    if autoaugment:
        train_ops.append(AutoAugment(policy=AutoAugmentPolicy.CIFAR10))
    train_ops.extend([
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
    ])
    if autoaugment:
        # Cutout works on the tensor, so it is applied after normalization.
        train_ops.append(Cutout(size = 8))
    transform_train = transforms.Compose(train_ops)

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
    ])

    train_dataset = datasets.CIFAR100(root=dataset_dir, train=True, download=True, transform=transform_train)
    # NOTE(review): download=False relies on the train-split construction
    # above having already placed the files under dataset_dir.
    test_dataset = datasets.CIFAR100(root=dataset_dir, train=False, download=False, transform=transform_test)

    # persistent_workers / prefetch_factor only apply to worker processes;
    # DataLoader rejects them when num_workers == 0, so pass them conditionally.
    worker_kwargs = {'persistent_workers': True, 'prefetch_factor': 2} if num_workers > 0 else {}

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True, **worker_kwargs
    )
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=num_workers, pin_memory=True)

    return train_loader, test_loader

def get_cifar10_loaders_val(batch_size=128, num_workers=4, val_split=0.1, seed=42, model_type='WRN'):
    """Build CIFAR-10 train / validation / test DataLoaders.

    The validation set is a seeded random subset of the official training
    split; it is read through a second dataset object so that it uses the
    evaluation transform instead of the training augmentation.

    Args:
        batch_size: Batch size of the train and validation loaders
            (the test loader uses 1000).
        num_workers: Number of worker processes per loader; ``0`` is supported
            and loads data in the main process.
        val_split: Fraction of the training split held out for validation.
        seed: Seed of the generator that permutes the training indices.
        model_type: ``"ViT"`` resizes images to 224; any other value keeps 32.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    input_size = 224 if model_type == "ViT" else 32
    mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)

    transform_train = transforms.Compose([
        transforms.Resize(input_size),
        transforms.RandomCrop(input_size, padding=input_size//8),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_val_test = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    # Same underlying training split twice: once augmented, once not.
    train_full = datasets.CIFAR10(
        root=dataset_dir, train=True, download=True,
        transform=transform_train
    )
    val_full = datasets.CIFAR10(
        root=dataset_dir, train=True, download=False,
        transform=transform_val_test
    )
    test_dataset = datasets.CIFAR10(
        root=dataset_dir, train=False, download=True,
        transform=transform_val_test
    )

    # One explicit permutation so both dataset objects share the same,
    # disjoint train / validation index sets.
    num_train = len(train_full)
    num_val = int(num_train * val_split)
    generator = torch.Generator().manual_seed(seed)
    indices = torch.randperm(num_train, generator=generator)
    val_indices = indices[:num_val].tolist()
    train_indices = indices[num_val:].tolist()

    train_subset = Subset(train_full, train_indices)
    val_subset   = Subset(val_full,   val_indices)

    # persistent_workers / prefetch_factor only apply to worker processes;
    # DataLoader rejects them when num_workers == 0, so pass them conditionally.
    worker_kwargs = {'persistent_workers': True, 'prefetch_factor': 2} if num_workers > 0 else {}

    train_loader = DataLoader(
        train_subset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        **worker_kwargs
    )
    val_loader = DataLoader(
        val_subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        **worker_kwargs
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=1000,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return train_loader, val_loader, test_loader

def get_cifar100_loaders_val(batch_size=128, num_workers=4, val_split=0.1, seed=42):
    """Build CIFAR-100 train / validation / test DataLoaders.

    The validation set is a seeded random subset of the official training
    split; it is read through a second dataset object so that it uses the
    evaluation transform instead of the training augmentation.

    Args:
        batch_size: Batch size of the train and validation loaders
            (the test loader uses 1000).
        num_workers: Number of worker processes per loader; ``0`` is supported
            and loads data in the main process.
        val_split: Fraction of the training split held out for validation.
        seed: Seed of the generator that permutes the training indices.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    cifar100_mean = (0.5071, 0.4867, 0.4408)
    cifar100_std  = (0.2675, 0.2565, 0.2761)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
    ])
    transform_val_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
    ])

    # Same underlying training split twice: once augmented, once not.
    train_full = datasets.CIFAR100(
        root=dataset_dir,
        train=True, download=True,
        transform=transform_train
    )
    val_full = datasets.CIFAR100(
        root=dataset_dir,
        train=True, download=False,
        transform=transform_val_test
    )
    test_dataset = datasets.CIFAR100(
        root=dataset_dir,
        train=False, download=True,
        transform=transform_val_test
    )

    # One explicit permutation so both dataset objects share the same,
    # disjoint train / validation index sets.
    num_train = len(train_full)
    num_val   = int(num_train * val_split)
    gen       = torch.Generator().manual_seed(seed)
    indices   = torch.randperm(num_train, generator=gen)
    train_idx = indices[num_val:].tolist()
    val_idx   = indices[:num_val].tolist()

    train_subset = Subset(train_full, train_idx)
    val_subset   = Subset(val_full,   val_idx)

    # persistent_workers / prefetch_factor only apply to worker processes;
    # DataLoader rejects them when num_workers == 0, so pass them conditionally.
    worker_kwargs = {'persistent_workers': True, 'prefetch_factor': 2} if num_workers > 0 else {}

    train_loader = DataLoader(
        train_subset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        **worker_kwargs
    )
    val_loader = DataLoader(
        val_subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        **worker_kwargs
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=1000,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return train_loader, val_loader, test_loader


def split_labeled_unlabeled(labels, num_labeled, num_classes=10, seed=None):
    """Split sample positions into class-balanced labeled and unlabeled sets.

    For every class ``0 .. num_classes - 1`` the positions carrying that label
    are shuffled with NumPy's global RNG; the first
    ``num_labeled // num_classes`` go to the labeled set and the rest to the
    unlabeled set.

    Args:
        labels: Array of integer class labels.
        num_labeled: Total number of labeled samples requested.
        num_classes: Number of classes to iterate over.
        seed: If not None, reseeds the global ``numpy.random`` and ``random``
            generators before splitting.

    Returns:
        Tuple ``(labeled_idx, unlabeled_idx)`` of NumPy arrays, grouped by class.
    """
    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)

    labeled_idx, unlabeled_idx = [], []
    for cls in range(num_classes):
        quota = num_labeled // num_classes
        members = np.where(labels == cls)[0]
        np.random.shuffle(members)
        head, tail = members[:quota], members[quota:]
        labeled_idx.extend(head)
        unlabeled_idx.extend(tail)
    return np.array(labeled_idx), np.array(unlabeled_idx)

def set_seed(seed):
    """Seed the ``random``, NumPy and PyTorch RNGs; do nothing if ``seed`` is None.

    When CUDA is available, all CUDA devices are seeded as well.
    """
    if seed is None:
        return
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def get_fixmatch_loaders(batch_size=64, num_workers=8, num_labeled=250, seed=5):
    """Build the FixMatch CIFAR-10 loaders: labeled, unlabeled and test.

    The training split is divided by ``split_labeled_unlabeled`` into a
    class-balanced labeled part and an unlabeled remainder. Unlabeled items
    are transformed by ``TransformFixMatch`` and batched at
    ``batch_size * 7``.

    Args:
        batch_size: Batch size of the labeled loader.
        num_workers: Number of worker processes per loader.
        num_labeled: Total number of labeled samples.
        seed: Seed passed to ``set_seed`` and to the split.

    Returns:
        Tuple ``(labeled_loader, unlabeled_loader, test_loader)``.
    """
    set_seed(seed)

    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]

    labeled_tf = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    unlabeled_tf = TransformFixMatch(mean=mean, std=std)
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    # The raw training split is only used as a source of arrays, so it gets
    # no transform of its own.
    train_source = datasets.CIFAR10(dataset_dir, train=True, download=True)
    test_dataset = datasets.CIFAR10(
        dataset_dir, train=False, download=True, transform=test_tf
    )

    all_labels = np.array(train_source.targets)
    labeled_idx, unlabeled_idx = split_labeled_unlabeled(
        all_labels, num_labeled, num_classes=10, seed=seed
    )

    labeled_dataset = CIFAR10SSL(
        train_source.data, all_labels, labeled_idx, transform=labeled_tf
    )
    unlabeled_dataset = CIFAR10SSL(
        train_source.data, all_labels, unlabeled_idx, transform=unlabeled_tf
    )

    # Options shared by both training loaders.
    train_opts = dict(
        shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True
    )
    labeled_loader = DataLoader(labeled_dataset, batch_size=batch_size, **train_opts)
    unlabeled_loader = DataLoader(
        unlabeled_dataset, batch_size=batch_size * 7, **train_opts
    )
    test_loader = DataLoader(
        test_dataset, batch_size=1000, shuffle=False, num_workers=num_workers
    )

    return labeled_loader, unlabeled_loader, test_loader



def get_fixmatch_loaders_val(batch_size=64, num_workers=10, num_labeled=250, val_ratio=0.1, seed=5):
    """Build FixMatch CIFAR-10 loaders with a held-out validation set.

    A seeded shuffle of the training split reserves the first
    ``int(N * val_ratio)`` samples for validation. The remainder is divided by
    ``split_labeled_unlabeled`` into labeled and unlabeled parts. Unlabeled
    items are transformed by ``TransformFixMatch`` and batched at
    ``batch_size * 7``.

    Args:
        batch_size: Batch size of the labeled loader.
        num_workers: Number of worker processes per loader.
        num_labeled: Total number of labeled samples.
        val_ratio: Fraction of the training split used for validation.
        seed: Seed passed to ``set_seed``, the shuffle and the split.

    Returns:
        Tuple ``(labeled_loader, unlabeled_loader, val_loader, test_loader)``.
    """
    set_seed(seed)
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]

    labeled_tf = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    unlabeled_tf = TransformFixMatch(mean=mean, std=std)
    eval_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    train_source = datasets.CIFAR10(dataset_dir, train=True, download=True)
    all_labels = np.array(train_source.targets)
    total = len(all_labels)

    # Seeded shuffle of every training position, then cut off the validation part.
    np.random.seed(seed)
    order = np.arange(total)
    np.random.shuffle(order)
    n_val = int(total * val_ratio)
    val_indices, train_indices = order[:n_val], order[n_val:]

    rel_labeled, rel_unlabeled = split_labeled_unlabeled(
        all_labels[train_indices], num_labeled, num_classes=10, seed=seed
    )
    # The split returns positions relative to train_indices; map them back to
    # positions in the full training split.
    labeled_idx = train_indices[rel_labeled]
    unlabeled_idx = train_indices[rel_unlabeled]

    labeled_dataset = CIFAR10SSL(
        train_source.data, all_labels, labeled_idx, transform=labeled_tf
    )
    unlabeled_dataset = CIFAR10SSL(
        train_source.data, all_labels, unlabeled_idx, transform=unlabeled_tf
    )
    # Validation reads the training split through the evaluation transform.
    val_source = datasets.CIFAR10(dataset_dir, train=True, download=True, transform=eval_tf)
    val_dataset = Subset(val_source, val_indices)
    test_dataset = datasets.CIFAR10(
        dataset_dir, train=False, download=True, transform=eval_tf
    )

    # Options shared by both training loaders / both evaluation loaders.
    train_opts = dict(
        shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True
    )
    eval_opts = dict(batch_size=1000, shuffle=False, num_workers=num_workers)

    labeled_loader = DataLoader(labeled_dataset, batch_size=batch_size, **train_opts)
    unlabeled_loader = DataLoader(
        unlabeled_dataset, batch_size=batch_size * 7, **train_opts
    )
    val_loader = DataLoader(val_dataset, **eval_opts)
    test_loader = DataLoader(test_dataset, **eval_opts)

    return labeled_loader, unlabeled_loader, val_loader, test_loader

def get_fashion_mnist_loaders(batch_size=128, num_workers=4):
    """Build the Fashion-MNIST training and test DataLoaders.

    Training images are randomly cropped (28, padding 4) and horizontally
    flipped; both splits are converted to tensors and normalized.

    Args:
        batch_size: Batch size of the training loader (the test loader uses 1000).
        num_workers: Number of worker processes per loader; ``0`` is supported
            and loads data in the main process.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    mean, std = (0.2860,), (0.3530,)

    transform = transforms.Compose([
        transforms.RandomCrop(28, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    train_dataset = datasets.FashionMNIST(root=dataset_dir, train=True, download=True, transform=transform)
    test_dataset = datasets.FashionMNIST(root=dataset_dir, train=False, download=True, transform=test_transform)

    # persistent_workers / prefetch_factor only apply to worker processes;
    # DataLoader rejects them when num_workers == 0, so pass them conditionally.
    worker_kwargs = {'persistent_workers': True, 'prefetch_factor': 2} if num_workers > 0 else {}

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        **worker_kwargs
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=1000,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return train_loader, test_loader

def get_svhn_loaders(batch_size=128, num_workers=4):
    """Build the SVHN training and test DataLoaders.

    Training images are randomly cropped (32, padding 4) and horizontally
    flipped; both splits are converted to tensors and normalized with the
    mean / std constants below.

    Args:
        batch_size: Batch size of the training loader (the test loader uses 1000).
        num_workers: Number of worker processes per loader; ``0`` is supported
            and loads data in the main process.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    # NOTE(review): these match the ImageNet statistics used elsewhere in this
    # file rather than SVHN-specific ones - confirm this is intended.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    train_dataset = datasets.SVHN(root=dataset_dir, split ="train", download=True, transform=transform)
    test_dataset = datasets.SVHN(root=dataset_dir, split ="test", download=True, transform=test_transform)

    # persistent_workers / prefetch_factor only apply to worker processes;
    # DataLoader rejects them when num_workers == 0, so pass them conditionally.
    worker_kwargs = {'persistent_workers': True, 'prefetch_factor': 2} if num_workers > 0 else {}

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        **worker_kwargs
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=1000,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return train_loader, test_loader


def get_imagenet_loaders(
    data_dir: str = "/home/dataset/imagenet/",
    batch_size: int = 256,
    num_workers: int = 8,
    image_size: int = 224,
    prefetch_factor: int = 2
):
    """
    Returns train and validation DataLoaders for ImageNet-1K.

    The training pipeline is a random resized crop plus horizontal flip
    (intended to follow the SAM paper's "Basic Augmentations"); validation
    resizes to 256 and center-crops to ``image_size``.

    Args:
        data_dir: ImageNet root passed to ``torchvision.datasets.ImageNet``
        batch_size: batch size used by both loaders
        num_workers: number of DataLoader workers; ``0`` loads in the main process
        image_size: final crop size (default 224)
        prefetch_factor: DataLoader prefetch_factor (per worker); only applied
            to the training loader and only when ``num_workers > 0``

    Returns:
        train_loader, val_loader
    """
    # ImageNet mean/std (RGB)
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(image_size),  # resize and random crop in one step
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    val_tf = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_ds = datasets.ImageNet(
        root=data_dir,
        split='train',
        transform=train_tf
    )
    val_ds = datasets.ImageNet(
        root=data_dir,
        split='val',
        transform=val_tf
    )

    # prefetch_factor and persistent_workers only apply to worker processes.
    # Passing prefetch_factor with num_workers == 0 can make DataLoader raise,
    # so both options are supplied only when workers are in use.
    if num_workers > 0:
        worker_kwargs = {
            'prefetch_factor': prefetch_factor,
            'persistent_workers': True,
        }
    else:
        worker_kwargs = {}

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        **worker_kwargs
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return train_loader, val_loader


def get_imagenet_datasets(
    data_dir: str = "/home/dataset/imagenet/",
    batch_size: int = 256,
    num_workers: int = 8,
    image_size: int = 224,
    prefetch_factor: int = 2,
    
):
    """
    Returns the ImageNet-1K train and validation datasets (no DataLoaders).

    Training images get a random resized crop and a horizontal flip;
    validation images are resized to 256 and center-cropped to ``image_size``.
    Both are converted to tensors and normalized.

    Args:
        data_dir: ImageNet root passed to ``torchvision.datasets.ImageNet``
        batch_size: accepted but not used by this function
        num_workers: accepted but not used by this function
        image_size: final crop size (default 224)
        prefetch_factor: accepted but not used by this function

    Returns:
        train_ds, val_ds
    """
    # ImageNet channel statistics (RGB).
    rgb_mean, rgb_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

    train_steps = [
        transforms.RandomResizedCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(rgb_mean, rgb_std),
    ]
    val_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(rgb_mean, rgb_std),
    ]

    train_ds = datasets.ImageNet(
        root=data_dir, split='train', transform=transforms.Compose(train_steps)
    )
    val_ds = datasets.ImageNet(
        root=data_dir, split='val', transform=transforms.Compose(val_steps)
    )
    return train_ds, val_ds


def get_cifar10_dataset():
    """
    Returns the CIFAR-10 train and test datasets.

    The training pipeline is random crop (32, padding 4), horizontal flip,
    the CIFAR10 AutoAugment policy, tensor conversion, normalization and
    Cutout; the test pipeline only converts to a tensor and normalizes.

    Returns:
        train_dataset, test_dataset
    """
    mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)

    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        AutoAugment(policy=AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
        Cutout(),
    ]
    test_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]

    train_dataset = datasets.CIFAR10(
        root=dataset_dir, train=True, download=True,
        transform=transforms.Compose(train_steps),
    )
    test_dataset = datasets.CIFAR10(
        root=dataset_dir, train=False, download=True,
        transform=transforms.Compose(test_steps),
    )
    return train_dataset, test_dataset

def get_cifar100_dataset():
    """
    Returns the CIFAR-100 train and test datasets.

    The training pipeline is resize to 32, random crop (32, padding 4),
    horizontal flip, the CIFAR10 AutoAugment policy, tensor conversion,
    normalization and ``Cutout(size=8)``; the test pipeline only converts to
    a tensor and normalizes.

    Returns:
        train_dataset, test_dataset
    """
    mean = (0.5071, 0.4867, 0.4408)
    std = (0.2675, 0.2565, 0.2761)

    train_steps = [
        transforms.Resize(32),
        transforms.RandomCrop(32, padding=32//8),
        transforms.RandomHorizontalFlip(),
        AutoAugment(policy=AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
        Cutout(size = 8),
    ]
    test_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]

    train_dataset = datasets.CIFAR100(
        root=dataset_dir, train=True, download=True,
        transform=transforms.Compose(train_steps),
    )
    test_dataset = datasets.CIFAR100(
        root=dataset_dir, train=False, download=True,
        transform=transforms.Compose(test_steps),
    )
    return train_dataset, test_dataset